{ "cells": [ { "cell_type": "markdown", "id": "817a5295", "metadata": {}, "source": [ "# Introduction to JumpStart - Text Summarization" ] }, { "attachments": {}, "cell_type": "markdown", "id": "cf4510e1", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "3f27d650", "metadata": {}, "source": [ "---\n", "Welcome to Amazon [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html)! You can use JumpStart to solve many machine learning tasks with one click in SageMaker Studio, or through the [SageMaker JumpStart API](https://sagemaker.readthedocs.io/en/stable/overview.html#use-prebuilt-models-with-sagemaker-jumpstart). \n", "\n", "In this demo notebook, we demonstrate how to use the JumpStart API for Text Summarization. Text Summarization is the task of condensing a text into a shorter summary that retains the most important information in the original. Here, we show how to use a state-of-the-art pre-trained DistilBART model for Text Summarization. \n", "\n", "---" ] }, { "cell_type": "markdown", "id": "61cb083d", "metadata": {}, "source": [ "1. [Set Up](#1.-Set-Up)\n", "2. [Select a model](#2.-Select-a-model)\n", "3. [Retrieve JumpStart Artifacts & Deploy an Endpoint](#3.-Retrieve-JumpStart-Artifacts-&-Deploy-an-Endpoint)\n", "4. [Query endpoint and parse response](#4.-Query-endpoint-and-parse-response)\n", "5. 
[Clean up the endpoint](#5.-Clean-up-the-endpoint)" ] }, { "cell_type": "markdown", "id": "c4405394", "metadata": {}, "source": [ "Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel and on an Amazon SageMaker Notebook instance with the conda_python3 kernel." ] }, { "cell_type": "markdown", "id": "ef2340f2", "metadata": {}, "source": [ "### 1. Set Up" ] }, { "cell_type": "markdown", "id": "1cd9d133", "metadata": {}, "source": [ "---\n", "Before executing the notebook, complete a few initial setup steps. This notebook requires the latest 2.x release of `sagemaker` (the install cell pins `sagemaker<3.0`) and `ipywidgets`.\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "778170fd", "metadata": {}, "outputs": [], "source": [ "!pip install 'sagemaker<3.0' ipywidgets --upgrade --quiet" ] }, { "cell_type": "markdown", "id": "3a217693", "metadata": {}, "source": [ "#### Permissions and environment variables\n", "\n", "---\n", "To host on Amazon SageMaker, we need to set up and authenticate the use of AWS services. Here, we use the execution role associated with the current notebook as the AWS account role with SageMaker access. \n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "e478668f", "metadata": {}, "outputs": [], "source": [ "import sagemaker, boto3, json\n", "from sagemaker import get_execution_role\n", "\n", "aws_role = get_execution_role()\n", "aws_region = boto3.Session().region_name\n", "sess = sagemaker.Session()" ] }, { "cell_type": "markdown", "id": "851878c4", "metadata": {}, "source": [ "### 2. Select a model\n", "\n", "***\n", "Here, we download the JumpStart `models_manifest.json` file from the JumpStart S3 bucket, filter out all the Text Summarization models, and select a model for inference.
\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "00398667", "metadata": {}, "outputs": [], "source": [ "import ipywidgets as widgets\n", "\n", "# download the JumpStart model manifest file.\n", "boto3.client(\"s3\").download_file(\n", "    f\"jumpstart-cache-prod-{aws_region}\", \"models_manifest.json\", \"models_manifest.json\"\n", ")\n", "with open(\"models_manifest.json\", \"rb\") as json_file:\n", "    model_list = json.load(json_file)\n", "\n", "# filter out all the Text Summarization models from the manifest list.\n", "text_summarization_models = []\n", "for model in model_list:\n", "    model_id = model[\"model_id\"]\n", "    if \"-summarization-\" in model_id and model_id not in text_summarization_models:\n", "        text_summarization_models.append(model_id)\n", "\n", "# display the model IDs in a dropdown to select a model for inference.\n", "model_dropdown = widgets.Dropdown(\n", "    options=text_summarization_models,\n", "    value=\"huggingface-summarization-distilbart-cnn-6-6\",\n", "    description=\"Select a model\",\n", "    style={\"description_width\": \"initial\"},\n", "    layout={\"width\": \"max-content\"},\n", ")" ] }, { "cell_type": "markdown", "id": "3c06c67b", "metadata": {}, "source": [ "#### Choose a model for inference" ] }, { "cell_type": "code", "execution_count": null, "id": "30160d1b", "metadata": {}, "outputs": [], "source": [ "display(model_dropdown)" ] }, { "cell_type": "code", "execution_count": null, "id": "b2dff735", "metadata": {}, "outputs": [], "source": [ "# model_version=\"*\" fetches the latest version of the model\n", "model_id, model_version = model_dropdown.value, \"*\"" ] }, { "cell_type": "markdown", "id": "16ffaf23", "metadata": {}, "source": [ "### 3. Retrieve JumpStart Artifacts & Deploy an Endpoint\n", "\n", "***\n", "\n", "Using JumpStart, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by retrieving the `deploy_image_uri`, `deploy_source_uri`, and `model_uri` for the pre-trained model. To host the pre-trained model, we create an instance of [`sagemaker.model.Model`](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html) and deploy it. This may take a few minutes.\n", "***" ] }, { "cell_type": "code", "execution_count": null, "id": "90e7bf0e", "metadata": {}, "outputs": [], "source": [ "from sagemaker import image_uris, model_uris, script_uris\n", "from sagemaker.model import Model\n", "from sagemaker.predictor import Predictor\n", "from sagemaker.utils import name_from_base\n", "\n", "\n", "endpoint_name = name_from_base(f\"jumpstart-example-infer-{model_id}\")\n", "\n", "inference_instance_type = \"ml.p2.xlarge\"\n", "\n", "# Retrieve the inference Docker container URI. This is the base HuggingFace container image for the default model above.\n", "deploy_image_uri = image_uris.retrieve(\n", "    region=None,\n", "    framework=None,  # automatically inferred from model_id\n", "    image_scope=\"inference\",\n", "    model_id=model_id,\n", "    model_version=model_version,\n", "    instance_type=inference_instance_type,\n", ")\n", "\n", "# Retrieve the inference script URI. This includes all dependencies and scripts for model loading, inference handling, etc.\n", "deploy_source_uri = script_uris.retrieve(\n", "    model_id=model_id, model_version=model_version, script_scope=\"inference\"\n", ")\n", "\n", "\n", "# Retrieve the model URI. 
This includes the pre-trained model and parameters.\n", "model_uri = model_uris.retrieve(\n", "    model_id=model_id, model_version=model_version, model_scope=\"inference\"\n", ")\n", "\n", "\n", "# Create the SageMaker model instance\n", "model = Model(\n", "    image_uri=deploy_image_uri,\n", "    source_dir=deploy_source_uri,\n", "    model_data=model_uri,\n", "    entry_point=\"inference.py\",  # entry point file in source_dir and present in deploy_source_uri\n", "    role=aws_role,\n", "    predictor_cls=Predictor,\n", "    name=endpoint_name,\n", ")\n", "\n", "# Deploy the model. Note that we need to pass the Predictor class when deploying through the Model class\n", "# to be able to run inference through the SageMaker API.\n", "model_predictor = model.deploy(\n", "    initial_instance_count=1,\n", "    instance_type=inference_instance_type,\n", "    predictor_cls=Predictor,\n", "    endpoint_name=endpoint_name,\n", ")" ] }, { "cell_type": "markdown", "id": "d51c8bb3", "metadata": {}, "source": [ "### 4. Query endpoint and parse response\n", "\n", "---\n", "Input to the endpoint is any string of text, encoded in `utf-8` format and sent with the `application/x-text` content type. Output of the endpoint is a `json` object containing the summarized text under the `summary_text` key.\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "c4ba6c2c", "metadata": {}, "outputs": [], "source": [ "def query(model_predictor, text):\n", "    \"\"\"Query the model predictor.\"\"\"\n", "\n", "    encoded_text = text.encode(\"utf-8\")\n", "\n", "    query_response = model_predictor.predict(\n", "        encoded_text,\n", "        {\n", "            \"ContentType\": \"application/x-text\",\n", "            \"Accept\": \"application/json\",\n", "        },\n", "    )\n", "    return query_response\n", "\n", "\n", "def parse_response(query_response):\n", "    \"\"\"Parse response and return summary text.\"\"\"\n", "\n", "    model_predictions = json.loads(query_response)\n", "    summary_text = model_predictions[\"summary_text\"]\n", "    return summary_text" ] }, { "cell_type": "markdown", "id": "e369a0c4", "metadata": {}, "source": [ "---\n", "Below, we provide some example input text. You can replace it with any text, and the model will summarize it.\n", "\n", "---" ] }, { "cell_type": "code", "execution_count": null, "id": "a1ed874e", "metadata": {}, "outputs": [], "source": [ "newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n", "\n", "input_text = \"The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). 
Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.\"\n", "\n", "query_response = query(model_predictor, input_text)\n", "\n", "summary_text = parse_response(query_response)\n", "\n", "print(f\"Input text: {input_text}{newline}\" f\"Summary text: {bold}{summary_text}{unbold}{newline}\")" ] }, { "cell_type": "markdown", "id": "c5e74707", "metadata": {}, "source": [ "### 5. Clean up the endpoint" ] }, { "cell_type": "code", "execution_count": null, "id": "8c92328b", "metadata": {}, "outputs": [], "source": [ "# Delete the SageMaker model and endpoint\n", "model_predictor.delete_model()\n", "model_predictor.delete_endpoint()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ec79d0ee", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2, which is shown at the top of the notebook.\n", "\n", "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This ca-central-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|jumpstart_text_summarization|Amazon_JumpStart_Text_Summarization.ipynb)\n" ] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science 3.0)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-310-v1" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" }, "pycharm": { "stem_cell": { "cell_type": "raw", "metadata": { "collapsed": false }, "source": [] } } }, "nbformat": 4, "nbformat_minor": 5 }